{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Reflection Workflow for Structured Outputs\n",
    "\n",
    "This notebook walks through setting up a `Workflow` to provide reliable structured outputs through retries and reflection on mistakes.\n",
    "\n",
    "This notebook works best with an open-source LLM, so we will use `Ollama`. If you don't already have Ollama running, visit [https://ollama.com](https://ollama.com) to get started and download the model you want to use. (In this case, we did `ollama pull llama3.1` before running this notebook)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -U llama-index llama-index-llms-ollama"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since workflows are async first, this all runs fine in a notebook. If you were running in your own code, you would want to use `asyncio.run()` to start an async event loop if one isn't already running.\n",
    "\n",
    "```python\n",
    "async def main():\n",
    "    <async code>\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    import asyncio\n",
    "    asyncio.run(main())\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Designing the Workflow\n",
    "\n",
    "To validate the structured output of an LLM, we need only two steps:\n",
    "1. Generate the structured output\n",
    "2. Validate that the output is proper JSON\n",
    "\n",
    "The key thing here is that, if the output is invalid, we **loop** until it is, giving error feedback to the next generation.\n",
    "\n",
    "### The Workflow Events\n",
    "\n",
    "To handle these steps, we need to define a few events:\n",
    "1. An event to pass on the generated extraction \n",
    "2. An event to give feedback when the extraction is invalid\n",
    "\n",
    "The other steps will use the built-in `StartEvent` and `StopEvent` events."
   ]
  },
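  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before wiring this into a `Workflow`, the generate/validate/retry loop described above can be sketched in plain Python. This is only an illustration under stated assumptions: `fake_generate` and `extract_with_reflection` are hypothetical names, the LLM is replaced by a stub that fails once and then succeeds, and validation is reduced to `json.loads`.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "\n",
    "def fake_generate(passage, feedback=None):\n",
    "    # Hypothetical stand-in for an LLM call: the first attempt is not JSON,\n",
    "    # and any attempt that receives feedback returns valid JSON.\n",
    "    if feedback is None:\n",
    "        return 'Sure! Here is the JSON you asked for.'\n",
    "    return '{\"cars\": []}'\n",
    "\n",
    "\n",
    "def extract_with_reflection(passage, max_retries=3):\n",
    "    feedback = None\n",
    "    for _ in range(max_retries):\n",
    "        # Step 1: generate, passing along feedback from the last failure\n",
    "        output = fake_generate(passage, feedback)\n",
    "        try:\n",
    "            # Step 2: validate (here, only that the output parses as JSON)\n",
    "            json.loads(output)\n",
    "        except json.JSONDecodeError as e:\n",
    "            feedback = f'Previous output: {output!r}. Error: {e}'\n",
    "            continue\n",
    "        return output\n",
    "    # Give up once the retry budget is spent\n",
    "    return None\n",
    "\n",
    "\n",
    "result = extract_with_reflection('I own a Fiat Panda.')\n",
    "print(result)\n",
    "```\n",
    "\n",
    "The workflow below follows the same shape, but each half of the loop becomes a step and the feedback travels in an event."
   ]
  },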
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.workflow import Event\n",
    "\n",
    "\n",
    "class ExtractionDone(Event):\n",
    "    output: str\n",
    "    passage: str\n",
    "\n",
    "\n",
    "class ValidationErrorEvent(Event):\n",
    "    error: str\n",
    "    wrong_output: str\n",
    "    passage: str"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Item to Extract\n",
    "\n",
    "To prompt our model, lets define a pydantic model we want to extract."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pydantic import BaseModel\n",
    "\n",
    "\n",
    "class Car(BaseModel):\n",
    "    brand: str\n",
    "    model: str\n",
    "    power: int\n",
    "\n",
    "\n",
    "class CarCollection(BaseModel):\n",
    "    cars: list[Car]"
   ]
  },
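  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what validating against this model looks like, here is a minimal sketch that assumes Pydantic v2 (which provides `model_validate_json`). The models are repeated so the snippet is self-contained, and the two sample JSON strings are made up for illustration.\n",
    "\n",
    "```python\n",
    "from pydantic import BaseModel, ValidationError\n",
    "\n",
    "\n",
    "class Car(BaseModel):\n",
    "    brand: str\n",
    "    model: str\n",
    "    power: int\n",
    "\n",
    "\n",
    "class CarCollection(BaseModel):\n",
    "    cars: list[Car]\n",
    "\n",
    "\n",
    "# Illustrative inputs: one valid, one with a non-numeric power value\n",
    "good = '{\"cars\": [{\"brand\": \"Fiat\", \"model\": \"Panda\", \"power\": 45}]}'\n",
    "bad = '{\"cars\": [{\"brand\": \"Fiat\", \"model\": \"Panda\", \"power\": \"fast\"}]}'\n",
    "\n",
    "parsed = CarCollection.model_validate_json(good)\n",
    "print(parsed.cars[0].power)\n",
    "\n",
    "try:\n",
    "    CarCollection.model_validate_json(bad)\n",
    "except ValidationError as e:\n",
    "    # The error text names the offending field, which is the kind of\n",
    "    # message that can be fed back to the model on the next attempt.\n",
    "    error_message = str(e)\n",
    "    print(error_message)\n",
    "```\n",
    "\n",
    "Note that this check goes beyond JSON syntax: the second string is well-formed JSON, but it is rejected because `power` cannot be parsed as an integer."
   ]
  },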
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The Workflow Itself\n",
    "\n",
    "With our events defined, we can construct our workflow and steps. \n",
    "\n",
    "Note that the workflow automatically validates itself using type annotations, so the type annotations on our steps are very helpful!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "from llama_index.core.workflow import (\n",
    "    Workflow,\n",
    "    StartEvent,\n",
    "    StopEvent,\n",
    "    Context,\n",
    "    step,\n",
    ")\n",
    "from llama_index.llms.ollama import Ollama\n",
    "\n",
    "EXTRACTION_PROMPT = \"\"\"\n",
    "Context information is below:\n",
    "---------------------\n",
    "{passage}\n",
    "---------------------\n",
    "\n",
    "Given the context information and not prior knowledge, create a JSON object from the information in the context.\n",
    "The JSON object must follow the JSON schema:\n",
    "{schema}\n",
    "\n",
    "\"\"\"\n",
    "\n",
    "REFLECTION_PROMPT = \"\"\"\n",
    "You already created this output previously:\n",
    "---------------------\n",
    "{wrong_answer}\n",
    "---------------------\n",
    "\n",
    "This caused the JSON decode error: {error}\n",
    "\n",
    "Try again, the response must contain only valid JSON code. Do not add any sentence before or after the JSON object.\n",
    "Do not repeat the schema.\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "class ReflectionWorkflow(Workflow):\n",
    "    max_retries: int = 3\n",
    "\n",
    "    @step\n",
    "    async def extract(\n",
    "        self, ctx: Context, ev: StartEvent | ValidationErrorEvent\n",
    "    ) -> StopEvent | ExtractionDone:\n",
    "        current_retries = await ctx.store.get(\"retries\", default=0)\n",
    "        if current_retries >= self.max_retries:\n",
    "            return StopEvent(result=\"Max retries reached\")\n",
    "        else:\n",
    "            await ctx.store.set(\"retries\", current_retries + 1)\n",
    "\n",
    "        if isinstance(ev, StartEvent):\n",
    "            passage = ev.get(\"passage\")\n",
    "            if not passage:\n",
    "                return StopEvent(result=\"Please provide some text in input\")\n",
    "            reflection_prompt = \"\"\n",
    "        elif isinstance(ev, ValidationErrorEvent):\n",
    "            passage = ev.passage\n",
    "            reflection_prompt = REFLECTION_PROMPT.format(\n",
    "                wrong_answer=ev.wrong_output, error=ev.error\n",
    "            )\n",
    "\n",
    "        llm = Ollama(\n",
    "            model=\"llama3\",\n",
    "            request_timeout=30,\n",
    "            # Manually set the context window to limit memory usage\n",
    "            context_window=8000,\n",
    "        )\n",
    "        prompt = EXTRACTION_PROMPT.format(\n",
    "            passage=passage, schema=CarCollection.schema_json()\n",
    "        )\n",
    "        if reflection_prompt:\n",
    "            prompt += reflection_prompt\n",
    "\n",
    "        output = await llm.acomplete(prompt)\n",
    "\n",
    "        return ExtractionDone(output=str(output), passage=passage)\n",
    "\n",
    "    @step\n",
    "    async def validate(\n",
    "        self, ev: ExtractionDone\n",
    "    ) -> StopEvent | ValidationErrorEvent:\n",
    "        try:\n",
    "            CarCollection.model_validate_json(ev.output)\n",
    "        except Exception as e:\n",
    "            print(\"Validation failed, retrying...\")\n",
    "            return ValidationErrorEvent(\n",
    "                error=str(e), wrong_output=ev.output, passage=ev.passage\n",
    "            )\n",
    "\n",
    "        return StopEvent(result=ev.output)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And thats it! Let's explore the workflow we wrote a bit.\n",
    "\n",
    "- We have one entry point, `extract` (the steps that accept `StartEvent`)\n",
    "- When `extract` finishes, it emits a `ExtractionDone` event\n",
    "- `validate` runs and confirms the extraction:\n",
    "  - If its ok, it emits `StopEvent` and halts the workflow\n",
    "  - If nots not, it returns a `ValidationErrorEvent` with information about the error\n",
    "- Any `ValidationErrorEvent` emitted will trigger the loop, and `extract` runs again!\n",
    "- This continues until the structured output is validated"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the Workflow!\n",
    "\n",
    "**NOTE:** With loops, we need to be mindful of runtime. Here, we set a timeout of 120s."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running step extract\n",
      "Step extract produced event ExtractionDone\n",
      "Running step validate\n",
      "Validation failed, retrying...\n",
      "Step validate produced event ValidationErrorEvent\n",
      "Running step extract\n",
      "Step extract produced event ExtractionDone\n",
      "Running step validate\n",
      "Step validate produced event StopEvent\n"
     ]
    }
   ],
   "source": [
    "w = ReflectionWorkflow(timeout=120, verbose=True)\n",
    "\n",
    "# Run the workflow\n",
    "ret = await w.run(\n",
    "    passage=\"I own two cars: a Fiat Panda with 45Hp and a Honda Civic with 330Hp.\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{ \"cars\": [ { \"brand\": \"Fiat\", \"model\": \"Panda\", \"power\": 45 }, { \"brand\": \"Honda\", \"model\": \"Civic\", \"power\": 330 } ] }\n"
     ]
    }
   ],
   "source": [
    "print(ret)"
   ]
  }
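  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The workflow returns a string, which is either the validated JSON or a plain-text message such as `Max retries reached`. As a rough sketch of how a caller might tell the two apart using only the standard library, consider the snippet below. `parse_result` is a hypothetical helper, and `ret` is hard-coded to the output printed above so the snippet runs on its own.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "\n",
    "def parse_result(ret):\n",
    "    # Return the parsed JSON, or None if the workflow gave up and\n",
    "    # returned a plain-text message instead of JSON.\n",
    "    try:\n",
    "        return json.loads(ret)\n",
    "    except json.JSONDecodeError:\n",
    "        return None\n",
    "\n",
    "\n",
    "# Copied from the printed output above\n",
    "ret = '{ \"cars\": [ { \"brand\": \"Fiat\", \"model\": \"Panda\", \"power\": 45 }, { \"brand\": \"Honda\", \"model\": \"Civic\", \"power\": 330 } ] }'\n",
    "\n",
    "data = parse_result(ret)\n",
    "if data is not None:\n",
    "    for car in data['cars']:\n",
    "        print(car['brand'], car['model'], car['power'])\n",
    "```\n",
    "\n",
    "Returning a distinct result type for failures would be cleaner, but this keeps the workflow itself unchanged."
   ]
  }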
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama-index-cDlKpkFt-py3.11",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
